package com.test.plugin;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
import org.apache.commons.lang.StringUtils;
import org.apache.derby.iapi.types.SQLVarbit;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.sql.Connection;
import java.util.List;
import java.util.Properties;

/**
 * MyBatis interceptor that adds a tenant column to SQL before it is prepared.
 *
 * <p>The SQL held by the intercepted {@link StatementHandler}'s {@link BoundSql} is parsed with
 * Druid, rewritten on the AST, serialized back and written into the {@code sql} property of the
 * {@code BoundSql}:
 * <ul>
 *   <li>SELECT / UPDATE / DELETE: {@code <tenant column> = '<tenant id>'} is AND-ed onto the
 *       WHERE clause;</li>
 *   <li>INSERT: the tenant column and its value are appended unless the column is already
 *       listed.</li>
 * </ul>
 *
 * <p>Based on https://www.cnblogs.com/yuananyun/p/8093853.html (that reference implementation
 * has bugs).
 *
 * <p>NOTE(review): sub-queries that appear inside a WHERE clause are not visited, and the tenant
 * predicate for every joined table is added to the outer WHERE clause rather than to the join
 * condition - confirm this is acceptable for outer joins.
 *
 * @author liuzijian
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})
})
public class TenantHelper implements Interceptor {

    /** Name of the tenant id column. */
    private String fieldName = "tenant_id";

    /**
     * SQL dialect handed to Druid:
     * db2 | h2 | hive | mysql | odps | oracle | phoenix | postgresql | sqlserver.
     */
    private String dbType = "mysql";

    /**
     * Rewrites the SQL of the intercepted statement handler, then lets the invocation proceed.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return whatever the intercepted method returns
     * @throws Throwable anything raised while parsing, rewriting or proceeding
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        String fieldValue = "123"; // tenant id; TODO: obtain it from a ThreadLocal later

        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        List<SQLStatement> sqlStatementList = SQLUtils.parseStatements(boundSql.getSql(), dbType);
        if (sqlStatementList == null || sqlStatementList.isEmpty()) {
            // Nothing was parsed, so there is nothing to rewrite; leave the SQL untouched.
            return invocation.proceed();
        }
        // Every parsed statement is serialized back below, so every one of them must be
        // rewritten - not only the first - or later statements would skip the tenant column.
        for (SQLStatement sqlStatement : sqlStatementList) {
            this.routing(sqlStatement, fieldValue);
        }
        MetaObject metaObject = SystemMetaObject.forObject(boundSql);
        metaObject.setValue("sql", SQLUtils.toSQLString(sqlStatementList, dbType));
        return invocation.proceed();
    }

    /**
     * Wraps the target with the proxy factory provided by MyBatis.
     *
     * @param target the object MyBatis offers for interception
     * @return the result of {@code Plugin.wrap(target, this)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the plugin configuration.
     *
     * <p>Supported property: {@code dialect} - overrides the SQL dialect used for parsing and
     * serializing. A missing or blank value keeps the current dialect.
     *
     * @param properties plugin properties; may be {@code null}
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        String dialect = properties.getProperty("dialect");
        if (StringUtils.isNotBlank(dialect)) {
            // Assign the field; a local variable here would silently drop the setting.
            this.dbType = dialect.trim();
        }
    }

    /**
     * Dispatches on the statement type: SELECT | INSERT | UPDATE | DELETE.
     * Any other statement type is left unchanged.
     *
     * @param sqlStatement parsed statement to rewrite in place
     * @param fieldValue   tenant id
     */
    private void routing(SQLStatement sqlStatement, String fieldValue) {
        if (sqlStatement instanceof SQLSelectStatement) {
            SQLSelectStatement sqlSelectStatement = (SQLSelectStatement) sqlStatement;
            SQLSelectQuery sqlSelectQuery = sqlSelectStatement.getSelect().getQuery();
            this.doSelect(sqlSelectQuery, fieldValue);
        } else if (sqlStatement instanceof SQLInsertStatement) {
            this.addInsertStatementCondition((SQLInsertStatement) sqlStatement, fieldValue);
        } else if (sqlStatement instanceof SQLUpdateStatement) {
            this.addUpdateStatementCondition((SQLUpdateStatement) sqlStatement, fieldValue);
        } else if (sqlStatement instanceof SQLDeleteStatement) {
            this.addDeleteStatementCondition((SQLDeleteStatement) sqlStatement, fieldValue);
        }
    }

    /**
     * Rewrites a query, descending into both sides of UNION / UNION ALL.
     *
     * @param query      query block or union query
     * @param fieldValue tenant id
     */
    private void doSelect(SQLSelectQuery query, String fieldValue) {
        if (query instanceof SQLSelectQueryBlock) {
            // Plain query block (no UNION / UNION ALL at this level).
            SQLSelectQueryBlock queryBlock = (SQLSelectQueryBlock) query;
            this.addSelectStatementCondition(queryBlock, queryBlock.getFrom(), fieldValue);
        } else if (query instanceof SQLUnionQuery) {
            SQLUnionQuery unionQuery = (SQLUnionQuery) query;
            this.doSelect(unionQuery.getLeft(), fieldValue);
            this.doSelect(unionQuery.getRight(), fieldValue);
        }
    }

    /**
     * Adds the tenant predicate for one table source of a query block.
     *
     * @param sqlSelectQueryBlock query block whose WHERE clause is extended
     * @param from                table source to handle: plain table, join or derived table
     * @param fieldValue          tenant id
     */
    private void addSelectStatementCondition(SQLSelectQueryBlock sqlSelectQueryBlock, SQLTableSource from, String fieldValue) {
        if (from instanceof SQLExprTableSource) {
            // Plain table. Only the alias is needed, so the table expression itself is not
            // inspected (it is not always a simple identifier, e.g. schema.table).
            SQLExpr where = this.appendTenantCondition(sqlSelectQueryBlock.getWhere(), from.getAlias(), fieldValue);
            sqlSelectQueryBlock.setWhere(where);
        } else if (from instanceof SQLJoinTableSource) {
            // Join: handle both sides against the same query block.
            SQLJoinTableSource sqlJoinTableSource = (SQLJoinTableSource) from;
            this.addSelectStatementCondition(sqlSelectQueryBlock, sqlJoinTableSource.getLeft(), fieldValue);
            this.addSelectStatementCondition(sqlSelectQueryBlock, sqlJoinTableSource.getRight(), fieldValue);
        } else if (from instanceof SQLSubqueryTableSource) {
            // Derived table, e.g. INNER JOIN (SELECT id FROM role) c ON c.id = a.id
            SQLSelect sqlSelect = ((SQLSubqueryTableSource) from).getSelect();
            this.doSelect(sqlSelect.getQuery(), fieldValue);
        }
        // TODO: other table source types are not handled yet.
    }

    /**
     * Appends the tenant column and value to an INSERT unless the column is already listed.
     *
     * <p>NOTE(review): the value is appended only to the clause returned by {@code getValues()};
     * confirm how multi-row VALUES lists and INSERTs without a column list should be treated.
     *
     * @param sqlInsertStatement statement to rewrite in place
     * @param fieldValue         tenant id
     */
    private void addInsertStatementCondition(SQLInsertStatement sqlInsertStatement, String fieldValue) {
        List<SQLExpr> columns = sqlInsertStatement.getColumns();
        for (SQLExpr column : columns) {
            if (this.isTenantColumn(column)) {
                return; // the caller already supplies the tenant column
            }
        }
        columns.add(new SQLIdentifierExpr(fieldName));
        if (sqlInsertStatement.getValues() != null) {
            // String literal, like the WHERE predicates, so the value is quoted in the output.
            sqlInsertStatement.getValues().addValue(new SQLCharExpr(fieldValue));
        }
        if (sqlInsertStatement.getQuery() != null) {
            // INSERT ... SELECT
            this.doInsertSelect(sqlInsertStatement.getQuery().getQuery(), fieldValue);
        }
    }

    /**
     * Tells whether an INSERT column expression names the tenant column.
     * Surrounding back-ticks / double quotes and letter case are ignored.
     *
     * @param column one entry of the INSERT column list
     * @return {@code true} if it names the tenant column
     */
    private boolean isTenantColumn(SQLExpr column) {
        String name = null;
        if (column instanceof SQLIdentifierExpr) {
            name = ((SQLIdentifierExpr) column).getName();
        } else if (column instanceof SQLPropertyExpr) {
            name = ((SQLPropertyExpr) column).getName();
        }
        if (name == null) {
            return false;
        }
        return fieldName.equalsIgnoreCase(StringUtils.strip(name, "`\""));
    }

    /**
     * Appends the tenant value to the select list of an INSERT ... SELECT source query,
     * descending into both sides of UNION / UNION ALL.
     *
     * @param query      source query of the INSERT
     * @param fieldValue tenant id
     */
    private void doInsertSelect(SQLSelectQuery query, String fieldValue) {
        if (query instanceof SQLSelectQueryBlock) {
            SQLSelectQueryBlock sqlSelectQueryBlock = (SQLSelectQueryBlock) query;
            sqlSelectQueryBlock.getSelectList().add(new SQLSelectItem(new SQLCharExpr(fieldValue)));
        } else if (query instanceof SQLUnionQuery) {
            SQLUnionQuery unionQuery = (SQLUnionQuery) query;
            this.doInsertSelect(unionQuery.getLeft(), fieldValue);
            this.doInsertSelect(unionQuery.getRight(), fieldValue);
        }
    }

    /**
     * Adds the tenant predicate to a DELETE statement.
     *
     * @param sqlDeleteStatement statement to rewrite in place
     * @param fieldValue         tenant id
     */
    private void addDeleteStatementCondition(SQLDeleteStatement sqlDeleteStatement, String fieldValue) {
        String alias = sqlDeleteStatement.getAlias();
        sqlDeleteStatement.setWhere(this.appendTenantCondition(sqlDeleteStatement.getWhere(), alias, fieldValue));
    }

    /**
     * Adds the tenant predicate to an UPDATE statement.
     * TODO: batch updates are not handled yet.
     *
     * @param sqlUpdateStatement statement to rewrite in place
     * @param fieldValue         tenant id
     */
    private void addUpdateStatementCondition(SQLUpdateStatement sqlUpdateStatement, String fieldValue) {
        String alias = sqlUpdateStatement.getTableSource().getAlias();
        sqlUpdateStatement.setWhere(this.appendTenantCondition(sqlUpdateStatement.getWhere(), alias, fieldValue));
    }

    /**
     * Builds {@code [alias.]<tenant column> = '<tenant id>'} and AND-s it onto the right-hand
     * side of an existing condition.
     *
     * @param originCondition current WHERE condition, or {@code null} if there is none
     * @param alias           table alias used to qualify the column; may be blank
     * @param fieldValue      tenant id
     * @return the tenant predicate alone when there was no condition, otherwise the combined one
     */
    private SQLExpr appendTenantCondition(SQLExpr originCondition, String alias, String fieldValue) {
        String reallyFieldName = StringUtils.isBlank(alias) ? fieldName : alias + "." + fieldName;
        SQLExpr tenantCondition = new SQLBinaryOpExpr(new SQLIdentifierExpr(reallyFieldName), new SQLCharExpr(fieldValue), SQLBinaryOperator.Equality);
        if (originCondition == null) {
            return tenantCondition;
        }
        // left == false: the new predicate always goes on the right-hand side.
        return SQLUtils.buildCondition(SQLBinaryOperator.BooleanAnd, tenantCondition, false, originCondition);
    }
}
